package mgrpc

import (
	"crypto/md5"
	"crypto/rand"
	"crypto/tls"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"math/big"
	"net"
	"sync"
	"time"

	"golang.org/x/net/context"

	"golang.org/x/net/websocket"

	"cesanta.com/common/go/mgrpc/codec"
	"cesanta.com/common/go/mgrpc/frame"
	"github.com/cesanta/errors"
	"github.com/golang/glog"
)

const (
	// authTypeDigest is the auth_type value, found in the JSON StatusMsg of a
	// 401 response, that makes Call retry the command with digest credentials.
	authTypeDigest = "digest"
)

// GetCredsCallback supplies the username and password that Call uses to
// answer a digest authentication challenge. A non-nil err aborts the call and
// is returned to the caller.
type GetCredsCallback func() (username, passwd string, err error)

// Handler processes an incoming request frame received on the given MgRPC
// instance and returns the response frame. The receive loop sends the
// returned frame back to the peer unless the request has NoResponse set.
type Handler func(MgRPC, *frame.Frame) *frame.Frame

// MgRPC is an RPC endpoint that can issue calls to peers and serve incoming
// requests through registered handlers.
type MgRPC interface {
	// Call sends cmd to dst and waits for the matching response. getCreds is
	// consulted when the peer answers with a digest authentication challenge.
	Call(
		ctx context.Context, dst string, cmd *frame.Command, getCreds GetCredsCallback,
	) (*frame.Response, error)
	// AddHandler registers handler for incoming requests with the given method.
	AddHandler(method string, handler Handler)
	// Disconnect closes the underlying codec.
	Disconnect(ctx context.Context) error
	// IsConnected reports whether the underlying codec says it is connected.
	IsConnected() bool
	// SetCodecOptions passes opts to the underlying codec.
	SetCodecOptions(opts *codec.Options) error
}

// mgRPCImpl is the MgRPC implementation returned by New and Serve.
type mgRPCImpl struct {
	// codec carries frames to and from the peer.
	codec codec.Codec

	// Map of outgoing requests keyed by command ID, and its lock. Entries are
	// added by Call and removed when a response or error is delivered to them
	// or when the call's context is done.
	reqs     map[int64]req
	reqsLock sync.Mutex

	// Map of handlers keyed by request method name, and its lock. recvLoop
	// looks incoming requests up here.
	handlers     map[string]Handler
	handlersLock sync.Mutex

	// opts holds the connection options; opts.localID is used as the source
	// address of outgoing request frames.
	opts *connectOptions

	// closing is set by Disconnect and tells recvLoop to stop.
	closing bool
}

// req tracks one outstanding Call: recvLoop delivers either the matching
// response on respChan or a failure on errChan, then forgets the entry.
type req struct {
	respChan chan *frame.Response
	errChan  chan error
}

// authErrorMsg is the JSON payload carried in the StatusMsg of a 401
// response. Call parses it to decide whether, and how, to retry with
// credentials.
type authErrorMsg struct {
	// AuthType selects the scheme; only authTypeDigest is acted upon.
	AuthType string `json:"auth_type"`
	// Nonce, NC and Realm are the challenge parameters fed into mkMd5Resp.
	Nonce int    `json:"nonce"`
	NC    int    `json:"nc"`
	Realm string `json:"realm"`
}

// tcpKeepAliveInterval is the TCP keep-alive period set on connections dialed
// by wsDialConfig and tcpConnect.
const tcpKeepAliveInterval = 3 * time.Minute

// ErrorResponse is an error type for failed commands. Intended for use by
// wrappers around Call() method, like ones generated by clubbygen.
type ErrorResponse struct {
	// Status is the numerical status code.
	Status int
	// Msg is a human-readable description of the error.
	Msg string
}

// Error implements the error interface. The text has the form
// "(<status>) <msg>".
func (e ErrorResponse) Error() string {
	code, text := e.Status, e.Msg
	return fmt.Sprintf("(%v) %v", code, text)
}

// New creates an MgRPC client for connectAddr, establishes the codec selected
// by opts (connectAddr is applied after them via connectTo) and starts the
// receive loop in a background goroutine.
func New(ctx context.Context, connectAddr string, opts ...ConnectOption) (MgRPC, error) {
	// Copy before appending so we never write into spare capacity of a slice
	// the caller still owns.
	allOpts := make([]ConnectOption, 0, len(opts)+1)
	allOpts = append(allOpts, opts...)
	allOpts = append(allOpts, connectTo(connectAddr))

	rpc := &mgRPCImpl{
		reqs: make(map[int64]req),
		// Without this, AddHandler on a client created by New would write to
		// a nil map and panic (Serve initializes it as well).
		handlers: make(map[string]Handler),
	}
	if err := rpc.connect(ctx, allOpts...); err != nil {
		return nil, errors.Trace(err)
	}

	go rpc.recvLoop(ctx, rpc.codec)

	return rpc, nil
}

// Serve wraps an already-established codec c in an MgRPC instance with an
// empty local ID, starts its receive loop in a background goroutine and
// returns immediately.
func Serve(ctx context.Context, c codec.Codec) MgRPC {
	srv := &mgRPCImpl{
		codec:    c,
		opts:     &connectOptions{localID: ""},
		handlers: make(map[string]Handler),
		reqs:     make(map[int64]req),
	}
	go srv.recvLoop(ctx, c)
	return srv
}

// wsDialConfig does the same thing as websocket.DialConfig, but also enables
// TCP keep-alive on the underlying connection.
//
// When the URL has no port, 80 is used for "ws" and 443 for "wss"; any other
// scheme yields websocket.ErrBadScheme. For "wss", if the TLS config does not
// name a server, the URL host is used as ServerName (as tls.Dial would do).
// The connection is closed if the WebSocket handshake fails.
func wsDialConfig(config *websocket.Config) (*websocket.Conn, error) {
	host, port, err := net.SplitHostPort(config.Location.Host)
	if err != nil {
		// Assuming that no port specified.
		host = config.Location.Host
		port = ""
	}

	switch config.Location.Scheme {
	case "ws":
		if port == "" {
			port = "80"
		}
	case "wss":
		if port == "" {
			port = "443"
		}
	default:
		return nil, errors.Trace(websocket.ErrBadScheme)
	}
	addr, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(host, port))
	if err != nil {
		return nil, errors.Annotate(err, "net.ResolveTCPAddr")
	}
	tc, err := net.DialTCP("tcp", nil, addr)
	if err != nil {
		return nil, errors.Annotate(err, "net.DialTCP")
	}
	// Keep-alive is best-effort; failing to enable it must not abort the dial.
	tc.SetKeepAlive(true)
	tc.SetKeepAlivePeriod(tcpKeepAliveInterval)
	var nc net.Conn = tc

	if config.Location.Scheme == "wss" {
		tlsConfig := config.TlsConfig
		if tlsConfig == nil || tlsConfig.ServerName == "" {
			// Unlike tls.Dial, tls.Client does not derive ServerName from the
			// address; without it the handshake fails unless
			// InsecureSkipVerify is set. Clone so the caller's config is not
			// mutated.
			if tlsConfig == nil {
				tlsConfig = &tls.Config{}
			} else {
				tlsConfig = tlsConfig.Clone()
			}
			tlsConfig.ServerName = host
		}
		nc = tls.Client(nc, tlsConfig)
	}

	conn, err := websocket.NewClient(config, nc)
	if err != nil {
		// The handshake failed; release the socket instead of leaking it.
		nc.Close()
		return nil, errors.Trace(err)
	}
	return conn, nil
}

// AddHandler registers handler for incoming requests whose Method equals
// method, replacing any handler previously registered for it. The handlers
// map is updated under handlersLock.
func (r *mgRPCImpl) AddHandler(method string, handler Handler) {
	r.handlersLock.Lock()
	defer r.handlersLock.Unlock()
	// Create the map lazily so registration never writes to a nil map.
	if r.handlers == nil {
		r.handlers = make(map[string]Handler)
	}
	r.handlers[method] = handler
}

// mqttConnect opens an MQTT codec for dst, using the TLS config and the MQTT
// codec options carried in opts.
func (r *mgRPCImpl) mqttConnect(dst string, opts *connectOptions) (codec.Codec, error) {
	tlsCfg := opts.tlsConfig
	mqttOpts := &opts.codecOptions.MQTT
	return codec.MQTT(dst, tlsCfg, mqttOpts)
}

// wsConnect dials url through wsDialConfig, offering codec.WSProtocol as the
// WebSocket subprotocol and using the TLS config from opts, and wraps the
// resulting connection in a WebSocket codec.
func (r *mgRPCImpl) wsConnect(url string, opts *connectOptions) (codec.Codec, error) {
	// TODO(imax): figure out what we should use as origin and what to check on the server side.
	const origin = "https://api.cesanta.com/"
	cfg, err := websocket.NewConfig(url, origin)
	if err != nil {
		return nil, errors.Trace(err)
	}
	cfg.Protocol, cfg.TlsConfig = []string{codec.WSProtocol}, opts.tlsConfig

	ws, err := wsDialConfig(cfg)
	if err != nil {
		return nil, errors.Trace(err)
	}
	return codec.WebSocket(ws), nil
}

// tcpConnect dials tcpAddress over plain TCP (no TLS yet), enables TCP
// keep-alive with tcpKeepAliveInterval and wraps the connection in a TCP
// codec.
func (r *mgRPCImpl) tcpConnect(tcpAddress string, opts *connectOptions) (codec.Codec, error) {
	// TODO(imax): add TLS support.
	conn, err := net.Dial("tcp", tcpAddress)
	if err != nil {
		return nil, errors.Trace(err)
	}
	// net.Dial with network "tcp" returns a *net.TCPConn.
	tcpConn := conn.(*net.TCPConn)
	tcpConn.SetKeepAlive(true)
	tcpConn.SetKeepAlivePeriod(tcpKeepAliveInterval)
	return codec.TCP(conn), nil
}

// serialConnect opens a serial codec on portName using the serial codec
// options from opts.
func (r *mgRPCImpl) serialConnect(
	ctx context.Context, portName string, opts *connectOptions,
) (codec.Codec, error) {
	sc, err := codec.Serial(ctx, portName, &opts.codecOptions.Serial)
	if err == nil {
		return sc, nil
	}
	return nil, errors.Trace(err)
}

// connect applies opts to a fresh connectOptions stored in r.opts and builds
// r.codec for the selected transport:
//   - HTTP POST uses an outbound HTTP codec;
//   - WebSocket, MQTT and plain TCP are always wrapped in a reconnecting codec;
//   - serial is wrapped in a reconnecting codec only when enableReconnect is
//     set, otherwise the port is opened directly.
//
// An option error is returned as-is; an unknown transport is an error.
func (r *mgRPCImpl) connect(ctx context.Context, opts ...ConnectOption) error {
	r.opts = &connectOptions{}
	for _, apply := range opts {
		if err := apply(r.opts); err != nil {
			return err
		}
	}

	glog.V(1).Infof("Connecting to %s over %s", r.opts.connectAddress, r.opts.proto)

	// dial opens a new underlying codec for an address; it is handed to the
	// reconnect wrapper for every transport that uses one.
	var dial func(addr string) (codec.Codec, error)

	switch r.opts.proto {
	case tHTTP_POST:
		r.codec = codec.OutboundHTTP(r.opts.connectAddress, r.opts.tlsConfig)
		return nil
	case tWebSocket:
		dial = func(wsURL string) (codec.Codec, error) {
			c, err := r.wsConnect(wsURL, r.opts)
			return c, errors.Trace(err)
		}
	case tMQTT:
		dial = func(mqttURL string) (codec.Codec, error) {
			c, err := r.mqttConnect(mqttURL, r.opts)
			return c, errors.Trace(err)
		}
	case tPlainTCP:
		dial = func(tcpAddress string) (codec.Codec, error) {
			c, err := r.tcpConnect(tcpAddress, r.opts)
			return c, errors.Trace(err)
		}
	case tSerial:
		if !r.opts.enableReconnect {
			serialCodec, err := r.serialConnect(ctx, r.opts.connectAddress, r.opts)
			if err != nil {
				return errors.Trace(err)
			}
			r.codec = serialCodec
			return nil
		}
		dial = func(serialAddress string) (codec.Codec, error) {
			c, err := r.serialConnect(ctx, serialAddress, r.opts)
			return c, errors.Trace(err)
		}
	default:
		return fmt.Errorf("unknown transport %q", r.opts.proto)
	}

	r.codec = codec.NewReconnectWrapperCodec(r.opts.connectAddress, dial)
	return nil
}

// Disconnect marks the instance as closing and closes the underlying codec.
// recvLoop checks the closing flag after each Recv and, once it is set, fails
// any outstanding requests and exits. Disconnect always returns nil.
func (r *mgRPCImpl) Disconnect(ctx context.Context) error {
	// closing may be read by the receive-loop goroutine; write it under
	// reqsLock so it can be read safely by code holding that lock.
	r.reqsLock.Lock()
	r.closing = true
	r.reqsLock.Unlock()
	r.codec.Close()
	return nil
}

// sendErrorResponse is the fallback Handler used when no handler is
// registered for f.Method. It returns a 404 error frame carrying f's ID.
func sendErrorResponse(r MgRPC, f *frame.Frame) *frame.Frame {
	notFound := &frame.Error{
		Code:    404,
		Message: fmt.Sprintf("Method [%s] not found", f.Method),
	}
	return &frame.Frame{ID: f.ID, Error: notFound}
}

// recvLoop reads frames from c until Recv returns an error or the instance
// is closing (see Disconnect).
//
// Frames with a non-empty Method are requests. They are passed to the handler
// registered for that method (sendErrorResponse if none), and the handler's
// result is sent back unless the request has NoResponse set. Every other frame
// is treated as a response and delivered to the Call waiting on its ID.
//
// On exit, every outstanding Call is failed with the Recv error. If the loop
// stopped because of Disconnect and Recv returned no error, a generic
// "connection closed" error is used instead. Disconnect is then called.
func (r *mgRPCImpl) recvLoop(ctx context.Context, c codec.Codec) {
	glog.V(2).Infof("Started recv loop, codec: %v", c)
	for {
		glog.V(2).Infof("recv ...")
		f, err := c.Recv(ctx)
		glog.V(2).Infof("done, %v", err)

		// closing may be set by Disconnect concurrently with this loop; read
		// it under reqsLock.
		r.reqsLock.Lock()
		closing := r.closing
		r.reqsLock.Unlock()

		if closing || err != nil {
			if closing {
				glog.Infof("devConn is disconnected, breaking out of the recvLoop: %v", err)
			} else {
				glog.Infof("error returned from codec Recv: %s, breaking out of the recvLoop", err)
			}
			if err == nil {
				// Never hand a waiting caller a nil error with no response.
				err = fmt.Errorf("mgrpc: connection closed")
			}
			// Detach all pending requests under the lock, but notify them
			// outside it so a slow receiver cannot stall other lock users.
			r.reqsLock.Lock()
			pending := r.reqs
			r.reqs = make(map[int64]req)
			r.reqsLock.Unlock()
			for _, p := range pending {
				p.errChan <- err
			}
			r.Disconnect(ctx)
			return
		}

		if glog.V(2) {
			s := fmt.Sprintf("%+v", f)
			if len(s) > 1024 {
				s = fmt.Sprintf("%s... (%d)", s[:1024], len(s))
			}
			glog.V(2).Infof("Rec'd %s", s)
		}

		if f.Method != "" {
			r.handlersLock.Lock()
			handler, ok := r.handlers[f.Method]
			r.handlersLock.Unlock()
			if !ok {
				handler = sendErrorResponse
			}
			resp := handler(r, f)
			if !f.NoResponse {
				if err := c.Send(ctx, resp); err != nil {
					glog.Errorf("error sending response to %q: %v", f.Method, err)
				}
			}
			continue
		}

		resp := frame.NewResponseFromFrame(f)
		r.reqsLock.Lock()
		pending, ok := r.reqs[resp.ID]
		delete(r.reqs, resp.ID)
		r.reqsLock.Unlock()
		if !ok {
			glog.Infof("ignoring unsolicited response: %v", resp)
			continue
		}
		// Delivered outside the lock so a caller that is no longer listening
		// cannot block every other user of reqsLock.
		pending.respChan <- resp
	}
}

// Call sends cmd to dst and blocks until the matching response arrives, the
// receive loop fails the request, or ctx is done.
//
// If cmd.ID is zero, a fresh ID is assigned to cmd in place. If the peer
// answers with status 401, the StatusMsg holds a digest challenge, and cmd
// carries no Auth yet, Call asks getCreds for credentials. It then retries
// once with a digest Auth attached. If getCreds is nil, the 401 response is
// returned unchanged.
func (r *mgRPCImpl) Call(
	ctx context.Context, dst string, cmd *frame.Command, getCreds GetCredsCallback,
) (*frame.Response, error) {
	if cmd.ID == 0 {
		cmd.ID = frame.CreateCommandUID()
	}

	// Buffered so that the single delivery from the receive loop never blocks,
	// even after this call has stopped waiting (e.g. because ctx is done).
	respChan := make(chan *frame.Response, 1)
	errChan := make(chan error, 1)

	r.reqsLock.Lock()
	r.reqs[cmd.ID] = req{respChan: respChan, errChan: errChan}
	r.reqsLock.Unlock()
	glog.V(2).Infof("created a request with id %d", cmd.ID)

	f := frame.NewRequestFrame(r.opts.localID, dst, "", cmd)
	if err := r.codec.Send(ctx, f); err != nil {
		// Nobody will wait for a reply, so drop the registration instead of
		// leaving a stale entry behind.
		r.reqsLock.Lock()
		delete(r.reqs, cmd.ID)
		r.reqsLock.Unlock()
		return nil, errors.Trace(err)
	}

	select {
	case resp := <-respChan:
		glog.V(2).Infof("got response to request %d: [%v] (%v)", cmd.ID, resp, resp.StatusMsg)
		if resp.Status != 401 || cmd.Auth != nil {
			return resp, nil
		}

		var authMsg authErrorMsg
		if err := json.Unmarshal([]byte(resp.StatusMsg), &authMsg); err != nil {
			glog.Warningf("got 401 with an invalid message: %v", resp.StatusMsg)
			return resp, nil
		}
		if authMsg.AuthType != authTypeDigest {
			glog.Warningf("got 401 with an unknown auth_type: %v", authMsg.AuthType)
			return resp, nil
		}
		if getCreds == nil {
			glog.Warningf("got 401 on request %d but no credentials callback was given", cmd.ID)
			return resp, nil
		}

		username, passwd, err := getCreds()
		if err != nil {
			return nil, errors.Trace(err)
		}

		cnonceBig, err := rand.Int(rand.Reader, big.NewInt(0xffffffff))
		if err != nil {
			return nil, errors.Annotatef(err, "generating cnonce")
		}
		cnonce := int(cnonceBig.Int64())

		digest := mkMd5Resp(
			"dummy_method", "dummy_uri", username, authMsg.Realm, passwd,
			authMsg.Nonce, authMsg.NC, cnonce, "auth",
		)

		cmdWithAuth := *cmd
		cmdWithAuth.Auth = &frame.FrameAuth{
			Realm:    authMsg.Realm,
			Nonce:    authMsg.Nonce,
			Username: username,
			CNonce:   cnonce,
			Response: digest,
		}
		glog.V(2).Infof("resending cmd %d with auth added: %v", cmd.ID, cmdWithAuth)
		// cmdWithAuth.Auth is non-nil, so the nested call will not retry again.
		return r.Call(ctx, dst, &cmdWithAuth, getCreds)

	case err := <-errChan:
		glog.V(2).Infof("got err on request %d: [%v]", cmd.ID, err)
		if err == nil {
			// Never return (nil, nil): callers would dereference a nil response.
			err = fmt.Errorf("request %d failed: connection closed", cmd.ID)
		}
		return nil, errors.Trace(err)

	case <-ctx.Done():
		glog.V(2).Infof("context for the request %d is done: %v", cmd.ID, ctx.Err())
		r.reqsLock.Lock()
		delete(r.reqs, cmd.ID)
		r.reqsLock.Unlock()
		return nil, errors.Trace(ctx.Err())
	}
}

// SendHello sends a "/v1/Hello" command to dst with a background context and
// logs the outcome at verbosity 2. No credentials are supplied: if the peer
// answers with a digest authentication challenge, the call fails rather than
// retrying.
func (r *mgRPCImpl) SendHello(dst string) {
	hello := &frame.Command{
		Cmd: "/v1/Hello",
	}
	// Call invokes getCreds on a digest challenge; report missing credentials
	// instead of passing a nil func that would be called and panic.
	noCreds := func() (string, string, error) {
		return "", "", fmt.Errorf("no credentials available for hello to %q", dst)
	}
	glog.V(2).Infof("Sending hello to %q", dst)
	resp, err := r.Call(context.Background(), dst, hello, noCreds)
	// %v, not %s: err is usually nil here and %s would render "%!s(<nil>)".
	glog.V(2).Infof("Hello response: %+v, %v", resp, err)
}

// IsConnected reports the IsConnected field of the underlying codec's Info().
func (r *mgRPCImpl) IsConnected() bool {
	return r.codec.Info().IsConnected
}

// SetCodecOptions forwards opts to the underlying codec and returns whatever
// error the codec reports.
func (r *mgRPCImpl) SetCodecOptions(opts *codec.Options) error {
	c := r.codec
	return c.SetOptions(opts)
}

// mkMd5Resp computes an MD5 digest-authentication response:
//
//	HA1      = MD5(username:realm:passwd)
//	HA2      = MD5(method:uri)
//	response = MD5(HA1:nonce:nc:cnonce:qop:HA2)
//
// Each hash is hex-encoded in lowercase, and nonce, nc and cnonce are
// formatted in decimal. The result is a 32-character hex string.
func mkMd5Resp(method, uri, username, realm, passwd string, nonce, nc, cnonce int, qop string) string {
	md5Hex := func(s string) string {
		sum := md5.Sum([]byte(s))
		return hex.EncodeToString(sum[:])
	}

	ha1 := md5Hex(fmt.Sprintf("%s:%s:%s", username, realm, passwd))
	ha2 := md5Hex(fmt.Sprintf("%s:%s", method, uri))

	// qop is taken from the parameter rather than hard-coded, so the caller's
	// value actually participates in the digest.
	return md5Hex(fmt.Sprintf("%s:%d:%d:%d:%s:%s", ha1, nonce, nc, cnonce, qop, ha2))
}
